import os
import itertools
import torch
import triton
import triton.language as tl
import triton.language.extra.libdevice as tldevice
import argparse
import numpy as np

# Set TRITON_PRECISION_NPU_SINGLE_THREAD=1 to make single_wrapper replay the
# launch grid one program at a time instead of issuing a single launch.
FLA_SINGLE_THREAD = os.getenv("TRITON_PRECISION_NPU_SINGLE_THREAD", '0') == '1'

def single_wrapper(grid, grid_size, kernel):
    """Launch ``kernel`` over ``grid`` in one go, or one program at a time.

    When ``FLA_SINGLE_THREAD`` is false, ``kernel`` is called exactly once with
    the full grid, ``USE_PIDS=False`` and every ``PID<n>`` set to 0.

    When it is true, ``kernel`` is called once per grid coordinate with a grid
    of ``(1,) * grid_size``, ``USE_PIDS=True`` and ``PID<n>`` set to that
    coordinate's index along axis ``n``. Coordinates are visited in
    ``itertools.product`` order (last axis varies fastest).

    Args:
        grid: Launch grid. Must be a tuple of ints in single-thread mode; it
            is forwarded untouched otherwise.
        grid_size: Number of ``PID<n>`` keyword arguments the kernel expects.
        kernel: Callable invoked as ``kernel(launch_grid, pid_kwargs)``.

    Raises:
        TypeError: In single-thread mode, if ``grid`` is not a tuple.
    """
    if not FLA_SINGLE_THREAD:
        pid_kwargs = {'USE_PIDS': False}
        pid_kwargs.update((f"PID{axis}", 0) for axis in range(grid_size))
        kernel(grid, pid_kwargs)
        return

    # An explicit raise (rather than `assert`) keeps the check alive under
    # `python -O`.
    if not isinstance(grid, tuple):
        raise TypeError(
            "single-thread mode needs `grid` as a tuple of ints, "
            f"got {type(grid).__name__}"
        )

    # Pad a short grid with extent-1 axes so that every one of the `grid_size`
    # PID arguments is supplied, exactly as in the parallel path above. A grid
    # that is already at least `grid_size` long is left as it is.
    extents = grid + (1,) * (grid_size - len(grid))
    single_grid = (1,) * grid_size

    for coords in itertools.product(*map(range, extents)):
        pid_kwargs = {'USE_PIDS': True}
        pid_kwargs.update(
            (f"PID{axis}", index) for axis, index in enumerate(coords)
        )
        kernel(single_grid, pid_kwargs)

# Module-level aliases for exp / exp2 / log / log2.
# FLA_USE_FAST_OPS=1 binds them to the libdevice functions; any other value
# (or an unset variable) binds them to the triton.language ones.
if os.getenv('FLA_USE_FAST_OPS', '0') != '1':
    exp, exp2, log, log2 = tl.exp, tl.math.exp2, tl.log, tl.log2
else:
    exp, exp2, log, log2 = (
        tldevice.fast_expf,
        tldevice.exp2,  # the one entry here without a `fast_` prefix
        tldevice.fast_logf,
        tldevice.fast_log2f,
    )

# HAS_SCALE and IS_VARLEN are not supplied by the caller: the heuristics below
# derive them at launch time from whether `scale` / `cu_seqlens` are None.
@triton.heuristics({
    'HAS_SCALE': lambda args: args['scale'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.jit(do_not_specialize=['T', 'PID0', 'PID1'])
def chunk_local_cumsum_scalar_kernel(
    s,
    o,
    scale,
    cu_seqlens,
    chunk_indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    BT: tl.constexpr,
    REVERSE: tl.constexpr,
    HAS_SCALE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_PIDS: tl.constexpr,
    PID0,
    PID1
):
    """Cumulative sum of one BT-sized block of `s`, written to `o`.

    Each program handles a single block of up to BT elements for one
    batch*head index: it loads the block, computes an inclusive cumulative
    sum in float32, optionally reverses and scales it, and stores the result
    at the same positions of `o`. Nothing is carried across blocks, so the
    sum restarts at every block boundary.

    Args:
        s: Input pointer.
        o: Output pointer; results are cast to its element type on store.
        scale: Multiplier applied to the result, or None for no scaling.
        cu_seqlens: When not None, selects variable-length mode. Entries
            `i_n` and `i_n + 1` give the start and end offset of sequence
            `i_n`.
        chunk_indices: Only read in variable-length mode. Entries `2*i` and
            `2*i + 1` give the sequence index and the block index inside
            that sequence for program `i` along grid axis 0.
        T: Sequence length. Replaced by `eos - bos` in variable-length mode.
        B: Not referenced in the kernel body (the test in this file passes
            the batch size).
        H: Divisor used to split grid axis 1 into `(i_b, i_h)`; also the
            element stride when HEAD_FIRST is false.
        BT: Number of elements per block.
        REVERSE: If true, each output is the sum from that position to the
            end of the block instead of from the start of the block.
        HAS_SCALE: Set by the heuristic; true when `scale` is not None.
        IS_VARLEN: Set by the heuristic; true when `cu_seqlens` is not None.
        HEAD_FIRST: Selects the addressing scheme (see the block pointers).
        USE_PIDS: If true, take the program coordinates from PID0 / PID1
            instead of `tl.program_id`.
        PID0: Block index along grid axis 0, used when USE_PIDS is true.
        PID1: Batch*head index along grid axis 1, used when USE_PIDS is true.
    """
    # Program coordinates: axis 0 selects the block, axis 1 the batch*head pair.
    if USE_PIDS:
        i_t, i_bh = PID0, PID1
    else:
        i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Axis 0 indexes a flat list of (sequence, block-in-sequence) pairs;
        # i_t is rebound to the block index inside that sequence.
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        # From here on T is the length of this particular sequence.
        T = eos - bos
    else:
        # Fixed-length sequences: batch i_b starts at i_b * T.
        # `eos` is not read again after this point.
        bos, eos = i_b * T, i_b * T + T

    # Both layouts expose a 1-D view of length T and select the BT-sized block
    # starting at i_t * BT; they differ only in base offset and stride.
    if HEAD_FIRST:
        # Unit stride, base bos*H + i_h*T.
        # NOTE(review): presumably a (B, H, T) layout — confirm against callers.
        p_s = tl.make_block_ptr(s + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    else:
        # Stride H, base bos*H + i_h; the test in this file uses this branch
        # with (B, T, H) tensors.
        p_s = tl.make_block_ptr(s + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    # [BT] values of this block, accumulated in float32 whatever the input dtype.
    # NOTE(review): no padding_option is passed, so what out-of-range lanes
    # hold is up to the backend; with REVERSE they feed into b_z — confirm they
    # are zero when T is not a multiple of BT.
    b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
    b_o = tl.cumsum(b_s, axis=0)
    if REVERSE:
        # total - prefix_sum + self == sum from this position to the block end.
        b_z = tl.sum(b_s, axis=0)
        b_o = -b_o + b_z[None] + b_s
    if HAS_SCALE:
        b_o *= scale
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))

def test_chunk_local_cumsum_scalar_kernel(output_file):
    """Run the chunk-local cumsum kernel on random data and save its output.

    This is a dump-style test: it makes no assertions. The kernel is launched
    through ``single_wrapper`` on seeded random input and the output tensor
    is written to ``output_file`` with ``np.savetxt``, one row per
    (batch, time) position and one column per head.

    Args:
        output_file: Destination handed to ``np.savetxt``.
    """
    # Fixed seed so that repeated runs see the same input.
    torch.manual_seed(42)

    # Problem size.
    batch, seq_len, heads = 8, 128, 8
    block_t = 64  # block size along the sequence axis

    # Random tensors on the NPU; float32 matches the type the kernel
    # computes in internally.
    tensor_opts = dict(dtype=torch.float32, device='npu')
    s = torch.randn(batch, seq_len, heads, **tensor_opts)
    o = torch.randn(batch, seq_len, heads, **tensor_opts)

    # No scaling and fixed-length sequences.
    scale = cu_seqlens = chunk_indices = None

    # One program per (block along the sequence, batch * head) pair.
    grid = (triton.cdiv(seq_len, block_t), batch * heads)

    # Launch callback in the form single_wrapper expects. Both feature flags
    # are disabled for this test.
    def kernel(cur_grid, pid_args):
        chunk_local_cumsum_scalar_kernel[cur_grid](
            s, o, scale, cu_seqlens, chunk_indices,
            T=seq_len, B=batch, H=heads, BT=block_t,
            REVERSE=False, HEAD_FIRST=False,
            **pid_args
        )

    single_wrapper(grid, len(grid), kernel)

    # Copy the result to the host and flatten everything but the head axis.
    result = o.cpu().detach().numpy()
    rows = result.reshape(-1, result.shape[-1])
    np.savetxt(output_file, rows)

if __name__ == "__main__":
    # Command-line entry point: run the kernel test once and write its output
    # to the file named by --output.
    cli = argparse.ArgumentParser(description='Test Chunk Local Cumsum Scalar Kernel')
    cli.add_argument(
        '--output',
        type=str,
        default='default_output.txt',
        help='Output file name (default: default_output.txt)',
    )
    test_chunk_local_cumsum_scalar_kernel(cli.parse_args().output)
    
